{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 极其简单的想法\n",
    "\n",
    "用训练数据来计算数字图像的平均暗度,0, 1, 2, . . . , 9。当有一幅新的图像\n",
    "呈现,我们先计算图像的暗度,然后猜测它接近哪个数字的平均暗度。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 导入库\n",
    "from collections import defaultdict\n",
    "import mnist_loader\n",
    "\n",
    "#加载数据\n",
    "training_data, validation, test_data = mnist_loader.load_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 计算数字图像的平均暗度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def avg_darknesses(training_data):\n",
    "    \"\"\"计算数字图像的平均暗度\n",
    "    \n",
    "        输入： 训练数据集\n",
    "                training_data[0]: 为图像的数据\n",
    "                training_data[1]: 对应的数字\n",
    "        \n",
    "        返回值： 键值为0~9的字典，对应的值为数据图像的平均暗度\"\"\"\n",
    "    digit_counts = defaultdict(int)\n",
    "    darknesses = defaultdict(float)\n",
    "    for image, digit in zip(training_data[0], training_data[1]):\n",
    "        digit_counts[digit] += 1\n",
    "        darknesses[digit] += sum(image)\n",
    "    avgs = defaultdict(float)\n",
    "    for digit, n in digit_counts.items():\n",
    "        avgs[digit] = darknesses[digit] / n\n",
    "    return avgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 计算数字图像的平均暗度\n",
    "avgs = avg_darknesses(training_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def guess_digit(image, avgs):\n",
    "    \"\"\"根据输入的图像与哪个数字的平均暗度最接近，来预测图像的数字\n",
    "    \n",
    "        输入： 1.待预测的数字图像 2.数字图像的平均暗度\n",
    "        \n",
    "        返回： 预测图像的数字\"\"\"\n",
    "    darkness = sum(image)\n",
    "    distances = {k: abs(v-darkness) for k,v in avgs.items()}\n",
    "    return min(distances, key=distances.get)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Baseline classifier using average darkness of image.\n",
      "2225 of 10000 values correct.\n"
     ]
    }
   ],
   "source": [
    "num_correct = sum(int(guess_digit(image, avgs) == digit) for image, digit in zip(test_data[0], test_data[1]))\n",
    "\n",
    "print(\"Baseline classifier using average darkness of image.\")\n",
    "print(\"%s of %s values correct.\" % (num_correct, len(test_data[1])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 总结"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以发现这是个很天真的想法，虽然比随机地猜测有了很大的改进,能取得 10000 测试图像中 2225 的精确度,即22.25%，但效果仍然不能让人满意。因此，需要努力找到其他方法，以获得更高的精确度。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
